#include "levenshtein.h"
#include <algorithm>
#include <map>
#include <tuple>
#include <vector>


namespace lp {

	//int subseq_dist(const Functor& e1, const Functor& e2)
	//{
	//	auto is_list = [](const Functor* f){ return f->is_pair() || f->is_empty_list(); };
	//	auto is_seq = [](const Functor* f){ return f->is_seq() || f->is_constant(); };
	//	int d = 0;
	//	for (decltype(e1.arity()) k = 0; k < e1.arity(); ++k) {
	//		const Functor* a1 = e1.arg(k);
	//		const Functor* a2 = e2.arg(k);
	//		if (is_list(a1) && is_list(a2)) {
	//			// Compare using more efficient diff
	//			const int r = subseq_dist(a1->list_begin(),a1->list_end(),a2->list_begin(),a2->list_end());
	//			if (r > 0) d += r;
	//		} else if (is_seq(a1) && is_seq(a2)) { // this case covers constants
	//			// Compare using more efficient diff
	//			const int r = subseq_dist(a1->seq_begin(),a1->seq_end(),a2->seq_begin(),a2->seq_end());
	//			if (r > 0) d += r;
	//		} else {
	//			// Different types => very different
	//			d += 100000;
	//		}
	//	}
	//	return d;
	//}


	// Sum of per-argument distances between two functors with the same signature.
	//
	// For each argument position k:
	//   - if both arguments are lists (pair or empty list), they are compared with
	//     diff() over their list iterators;
	//   - else if both are sequences or constants, with diff() over their sequence
	//     iterators;
	//   - otherwise (mixed kinds) with the tree edit distance tdist().
	//
	// Throws not_comparable if the signatures differ. The loop assumes equal
	// signatures imply equal arity — NOTE(review): confirm signature() covers arity.
	int dissimilarity(const lp::Functor& e1, const lp::Functor& e2)
	{
		if (e1.signature() != e2.signature()) {
			throw not_comparable();
		}
		auto is_list = [](const Functor* f){ return f->is_pair() || f->is_empty_list(); };
		auto is_seq = [](const Functor* f){ return f->is_seq() || f->is_constant(); };
		int d = 0;
		for (decltype(e1.arity()) k = 0; k < e1.arity(); ++k) {
			const Functor* a1 = e1.arg(k);
			const Functor* a2 = e2.arg(k);
			if (is_list(a1) && is_list(a2)) {
				// Both lists: compare with the cheaper linear diff
				d += diff(a1->list_begin(),a1->list_end(),a2->list_begin(),a2->list_end());
			} else if (is_seq(a1) && is_seq(a2)) { // this case covers constants
				// Both sequences/constants: compare with the cheaper linear diff
				d += diff(a1->seq_begin(),a1->seq_end(),a2->seq_begin(),a2->seq_end());
			} else {
				// Arguments of different kinds are still comparable as trees,
				// so fall back to the general tree edit distance.
				d += tdist(*a1,*a2);
			}
		}
		return d;
	}


	// Memoised edit distance between two ordered forests given as postorder ranges.
	//
	// xl and yl list the nodes of two trees in postorder. This call compares the
	// forests xl[xbeg..xend] and yl[ybeg..yend] (inclusive bounds; beg > end means
	// the forest is empty).
	// Costs:
	//   - deleting or inserting one node costs dcost;
	//   - mapping a node onto a node with a different id() costs mcost;
	//   - mapping it onto a node with an equal id() costs 0.
	//
	// Each step looks at the rightmost root of both forests (xend, yend) and takes
	// the cheapest of:
	//   p1: delete xend,
	//   p2: insert yend,
	//   p3: map xend onto yend. Each forest is then split into the part left of its
	//       rightmost tree and that tree without its root, and the parts are
	//       compared pairwise.
	//
	// memo is keyed by (xbeg,xend,ybeg,yend) only. A memo must therefore not be
	// shared between calls with different xl, yl, dcost or mcost.
	// The early-exit shortcuts below rely on dcost >= 0 and mcost >= 0.
	int tdist(
		int xbeg, int xend,
		int ybeg, int yend,
		const std::vector<const lp::Functor*>& xl,
		const std::vector<const lp::Functor*>& yl,
		std::map<std::tuple<int,int,int,int>,int>& memo,
		int dcost,
		int mcost)
	{
		// Index in l of the leftmost leaf of the subtree rooted at l[i]. In postorder
		// that leaf is the first node of the subtree, so the subtree spans [result, i].
		auto left_most = [](int i, const std::vector<const lp::Functor*>& l) -> int {
			const lp::Functor* f = l[i];
			while (!f->is_leaf()) f = f->arg_first();
			// Descendants precede their root in postorder, so searching l[0..i]
			// suffices; the bounded search cannot run past the end of l.
			const auto last = l.begin() + (i + 1);
			return static_cast<int>(std::find(l.begin(), last, f) - l.begin());
		};

		// Empty forests: every remaining node of the other forest must be
		// inserted/deleted, each at cost dcost.
		if (xbeg > xend) {
			return (ybeg > yend) ? 0 : (yend - ybeg + 1) * dcost;
		}
		if (ybeg > yend) {
			return (xend - xbeg + 1) * dcost;
		}

		const auto key = std::make_tuple(xbeg, xend, ybeg, yend);
		const auto hit = memo.find(key);
		if (hit != memo.end()) {
			return hit->second;
		}
		// Store the result for this sub-problem and hand it back.
		auto remember = [&](int value) -> int {
			memo.emplace(key, value);
			return value;
		};

		const int leftx = left_most(xend, xl);
		const int lefty = left_most(yend, yl);
		// p1: delete the rightmost root of x.
		auto delete_x = [&]() -> int {
			return tdist(xbeg, xend - 1, ybeg, yend, xl, yl, memo, dcost, mcost) + dcost;
		};
		// p2: insert the rightmost root of y.
		auto insert_y = [&]() -> int {
			return tdist(xbeg, xend, ybeg, yend - 1, xl, yl, memo, dcost, mcost) + dcost;
		};
		// p3: map root xend onto root yend at the given cost.
		auto match = [&](int cost) -> int {
			return tdist(xbeg, leftx - 1, ybeg, lefty - 1, xl, yl, memo, dcost, mcost)
				+ tdist(leftx, xend - 1, lefty, yend - 1, xl, yl, memo, dcost, mcost)
				+ cost;
		};

		const int match_cost = (xl[xend]->id() == yl[yend]->id()) ? 0 : mcost;
		if (match_cost == 0) {
			// Roots match: try the free mapping first as it may be cheapest.
			const int p3 = match(match_cost);
			// p1 and p2 are each at least dcost, so p3 <= dcost cannot be beaten.
			if (p3 <= dcost) return remember(p3);
			const int p1 = delete_x();
			// dcost is the lower bound for p1/p2, and p3 > dcost here.
			if (p1 == dcost) return remember(p1);
			const int p2 = insert_y();
			return remember(std::min(std::min(p1, p2), p3));
		}

		// Roots differ: p3 >= match_cost, so a p1/p2 at its lower bound dcost
		// that does not exceed match_cost is already optimal.
		const int p1 = delete_x();
		if (p1 == dcost && p1 <= match_cost) return remember(p1);
		const int p2 = insert_y();
		if (p2 == dcost && p2 <= match_cost) return remember(p2);
		const int p3 = match(match_cost);
		return remember(std::min(std::min(p1, p2), p3));
	}


	// Tree edit distance between f and g.
	// Both trees are flattened to postorder node lists and handed to the memoised
	// forest recursion. Insert/delete costs 1 and relabelling (mapping nodes
	// with different id()) costs 2.
	int tdist(const lp::Functor& f, const lp::Functor& g)
	{
		// Collect the addresses of a tree's nodes in postorder.
		auto flatten = [](const lp::Functor& tree) {
			std::vector<const lp::Functor*> nodes;
			for (auto it = tree.post_begin(); it != tree.post_end(); ++it) {
				nodes.push_back(&*it);
			}
			return nodes;
		};
		const std::vector<const lp::Functor*> xs = flatten(f);
		const std::vector<const lp::Functor*> ys = flatten(g);
		std::map<std::tuple<int,int,int,int>,int> cache;
		const int insert_delete_cost = 1;
		const int relabel_cost = 2;
		const int xlast = static_cast<int>(xs.size()) - 1;
		const int ylast = static_cast<int>(ys.size()) - 1;
		return tdist(0, xlast, 0, ylast, xs, ys, cache, insert_delete_cost, relabel_cost);
	}

}

